"""
Phase 5: Bilateral Gravity — Safe Asset Seeking
=================================================
Tests whether aging origin countries disproportionately send portfolio debt
to safe-rated destinations. Uses bilateral CPIS data from gravity_bilateral.

Output: output/tables/table5_gravity_safe.md (under PROJECT_DIR)
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Project root for this phase; the sibling phase directories (multilateral/,
# gravity_bilateral/) are resolved relative to its parent.
# NOTE(review): absolute, machine-specific path — consider making configurable.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_assets")
MULTILATERAL_DIR = PROJECT_DIR.parent / "multilateral"
GRAVITY_DIR = PROJECT_DIR.parent / "gravity_bilateral"
# Prepend multilateral/src to the import path so the shared PanelGLS estimator
# can be imported; this must run before the import on the next line.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS

# Input directory (safe_asset_panel.csv) and the Table 5 output directory.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# Financial-centre jurisdictions (ISO3). Robustness model M5a drops any pair
# whose origin OR destination is in this list.
FINANCIAL_CENTERS = (
    "LUX IRL HKG SGP CHE BMU CYM VGB GGY JEY IMN BHS PAN MLT CYP BHR MUS"
).split()

# The 38 OECD member countries (ISO3); used to split origins into OECD
# (M5b) and non-OECD (M5c) subsamples.
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()


def stars(p):
    """Map a p-value to conventional significance stars.

    Returns '***' below 0.01, '**' below 0.05, '*' below 0.10 and '' otherwise
    (including NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_gravity(df, dep_var, regressors, label, feature_names=None,
                entity_col='pair_id', time_col='year', min_obs=100):
    """Run the PanelGLS gravity model on bilateral panel data.

    Parameters
    ----------
    df : pandas.DataFrame
        Bilateral panel holding ``dep_var``, the regressors and the
        entity/time identifier columns.
    dep_var : str
        Dependent-variable column.
    regressors : list of str
        Candidate regressor columns; those absent from ``df`` are dropped.
    label : str
        Model label used in console output and stored in the results dict.
    feature_names : list of str, optional
        Display names for the regressors. If it is parallel to
        ``regressors``, names are filtered in lockstep with missing columns.
        If it already matches the filtered list, it is used as-is. Otherwise
        the column names are used.
    entity_col, time_col : str
        Panel identifier columns passed to ``PanelGLS.fit``.
    min_obs : int
        Minimum usable observations required to estimate (default 100).

    Returns
    -------
    dict or None
        Model summary (``label``, ``dep_var``, ``n_obs``, ``n_pairs``,
        ``r_squared``, ``rho``) plus ``coef_<name>``, ``se_<name>`` and
        ``p_<name>`` per regressor. Returns None when the model is skipped:
        the dependent variable is missing, no regressor is available, or
        fewer than ``min_obs`` finite observations remain.
    """
    if dep_var not in df.columns:
        print(f"  [{label}] Dep var {dep_var} missing — skipping")
        return None

    present = [c for c in regressors if c in df.columns]
    if not present:
        print(f"  [{label}] No regressors available — skipping")
        return None

    # Keep display names attached to their own columns: filtering the names in
    # lockstep prevents a dropped regressor from shifting later labels.
    if feature_names and len(feature_names) == len(regressors):
        names = [n for c, n in zip(regressors, feature_names) if c in df.columns]
    elif feature_names and len(feature_names) == len(present):
        names = list(feature_names)
    else:
        names = present
    regressors = present

    sub = df.dropna(subset=[dep_var] + regressors + [entity_col, time_col])
    # dropna() keeps +/-inf (e.g. from a log of a zero position upstream), and
    # such rows would both inflate the obs count and poison the GLS fit.
    values = sub[[dep_var] + regressors].to_numpy(dtype=float)
    sub = sub[np.isfinite(values).all(axis=1)].copy()
    if len(sub) < min_obs:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub[entity_col].values, sub[time_col].values)

    n_pairs = sub[entity_col].nunique()
    print(f"\n  [{label}]  N={gls.n_obs}, pairs={n_pairs}, "
          f"R²={gls.r_squared:.4f}, rho={gls.rho:.3f}")

    results = {
        'label': label,
        'dep_var': dep_var,
        'n_obs': gls.n_obs,
        'n_pairs': n_pairs,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    # Coefficients are assumed to be ordered like the regressor columns.
    for i, name in enumerate(names):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<30} {gls.beta[i]:>8.4f} ({gls.se[i]:.4f}) {sig}")

    return results


def _print_section(title):
    """Print a section banner framed by horizontal rules."""
    print("\n" + "─" * 60)
    print(title)
    print("─" * 60)


def _attach_safe_status(bp, safe_panel):
    """Merge time-varying safe-issuer status onto both ends of each pair.

    Adds ``dest_safe``/``dest_rating`` (matched on ``iso_d``, ``year``) and
    then ``orig_safe``/``orig_rating`` (matched on ``iso_o``, ``year``).
    Country-years missing from ``safe_panel`` are coded as not safe (0).

    Raises
    ------
    pandas.errors.MergeError
        If ``safe_panel`` has more than one row per (iso3, year). Such
        duplicates would otherwise silently multiply bilateral observations.
    """
    for iso_col, prefix in (('iso_d', 'dest'), ('iso_o', 'orig')):
        renamed = safe_panel.rename(columns={
            'iso3': iso_col,
            'safe_issuer': f'{prefix}_safe',
            'rating_numeric': f'{prefix}_rating',
        })
        bp = bp.merge(renamed, on=[iso_col, 'year'], how='left',
                      validate='many_to_one')
        bp[f'{prefix}_safe'] = bp[f'{prefix}_safe'].fillna(0).astype(int)
    return bp


def _add_interactions(bp):
    """Add the interaction columns used in Sections B-D (modifies ``bp``).

    - ``dZ_k_x_dest_safe`` = dZ_k * dest_safe
    - if ``kaopen_j`` exists: ``kaopen_j_x_dest_safe``, ``dZ_k_x_kaopen_j``
      (only when the panel does not already supply it) and the triple term
      ``dZ_k_x_kaopen_dest_safe`` = dZ_k * kaopen_j * dest_safe
    - ``Z_k_i_x_dest_safe`` = Z_k_i * dest_safe (origin-level aging)

    Only interactions whose inputs are present are built. Returns ``bp``.
    """
    demo = ['dZ_1', 'dZ_2', 'dZ_3']
    for z in demo:
        if z in bp.columns:
            bp[f'{z}_x_dest_safe'] = bp[z] * bp['dest_safe']

    if 'kaopen_j' in bp.columns:
        # Lower-order terms of the triple interaction (see Section C).
        bp['kaopen_j_x_dest_safe'] = bp['kaopen_j'] * bp['dest_safe']
        for z in demo:
            if z in bp.columns:
                if f'{z}_x_kaopen_j' not in bp.columns:
                    bp[f'{z}_x_kaopen_j'] = bp[z] * bp['kaopen_j']
                bp[f'{z}_x_kaopen_dest_safe'] = (
                    bp[z] * bp['kaopen_j'] * bp['dest_safe'])

    for z in ['Z_1_i', 'Z_2_i', 'Z_3_i']:
        if z in bp.columns:
            bp[f'{z}_x_dest_safe'] = bp[z] * bp['dest_safe']
    return bp


def main():
    """Estimate every Phase 5 specification and write Table 5.

    Steps:
    1. Load the bilateral panel.
    2. Attach destination/origin safe-issuer status.
    3. Build the interaction terms.
    4. Estimate Sections A-E (baseline, safe-seeking, triple interaction,
       origin aging, robustness).
    5. Pass all successful results to ``build_table``.
    """
    print("=" * 70)
    print("PHASE 5: Bilateral Gravity — Safe Asset Seeking")
    print("=" * 70)

    # ── [1] Load bilateral panel ──
    print("\n[1] Loading bilateral panel ...")
    bp_path = GRAVITY_DIR / "data" / "processed" / "bilateral_panel.csv"
    bp = pd.read_csv(bp_path)
    print(f"  Bilateral panel: {len(bp):,} obs, {bp['pair_id'].nunique():,} pairs")

    # ── [2] Load safe asset panel for ratings ──
    print("\n[2] Adding destination safe issuer status ...")
    safe_panel = pd.read_csv(PROCESSED_DIR / "safe_asset_panel.csv",
                             usecols=['iso3', 'year', 'safe_issuer', 'rating_numeric'])
    bp = _attach_safe_status(bp, safe_panel)

    safe_rows = bp['dest_safe'] == 1
    n_safe_dest = bp.loc[safe_rows, 'iso_d'].nunique()
    print(f"  Safe destinations: {n_safe_dest} countries")
    # This counts pair-year observations, not distinct pairs.
    print(f"  Obs with safe dest: {int(safe_rows.sum()):,} / {len(bp):,}")

    # ── [3] Construct interaction terms ──
    print("\n[3] Constructing interaction terms ...")
    bp = _add_interactions(bp)

    all_results = []

    def estimate(data, dep_var, regs, label):
        """Fit one specification; keep its results only if it was estimated."""
        res = run_gravity(data, dep_var, regs, label, regs)
        if res is not None:
            all_results.append(res)

    # Gravity controls
    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official',
                    'colonial_ties', 'log_gdp_product']
    gravity_avail = [v for v in gravity_vars if v in bp.columns]

    demo_diff = ['dZ_1', 'dZ_2', 'dZ_3']
    demo_avail = [v for v in demo_diff if v in bp.columns]
    safe_interact = [f'{z}_x_dest_safe' for z in demo_diff
                     if f'{z}_x_dest_safe' in bp.columns]
    # Shared regressor set for the safe-seeking and robustness models.
    base_safe = gravity_avail + demo_avail + ['dest_safe'] + safe_interact

    # ── SECTION A: Baseline Gravity with Safe Destination ──
    _print_section("A. Baseline Gravity with Safe Destination")
    estimate(bp, 'log_portfolio_debt', gravity_avail + demo_avail,
             "M1a: baseline gravity")
    estimate(bp, 'log_portfolio_debt',
             gravity_avail + demo_avail + ['dest_safe'], "M1b: + dest_safe")

    # ── SECTION B: Safe-Seeking (dZ × dest_safe) ──
    # The equity model (M2b) is included as a comparison for the debt model.
    _print_section("B. Safe-Seeking: dZ × dest_safe")
    for dep_var, label in (('log_portfolio_debt', "M2a: dZ×dest_safe → debt"),
                           ('log_portfolio_equity', "M2b: dZ×dest_safe → equity"),
                           ('log_portfolio_total', "M2c: dZ×dest_safe → total")):
        estimate(bp, dep_var, base_safe, label)

    # ── SECTION C: Triple Interaction (dZ × KAOPEN × dest_safe) ──
    _print_section("C. Triple: dZ × KAOPEN_dest × dest_safe")
    kaopen_avail = [f'{z}_x_kaopen_j' for z in demo_diff
                    if f'{z}_x_kaopen_j' in bp.columns]
    triple_interact = [f'{z}_x_kaopen_dest_safe' for z in demo_diff
                       if f'{z}_x_kaopen_dest_safe' in bp.columns]
    if triple_interact:
        # A triple interaction is only interpretable alongside all of its
        # lower-order terms. dZ, dest_safe and dZ×dest_safe come from
        # base_safe; kaopen_j, kaopen_j×dest_safe and dZ×kaopen_j are added here.
        kaopen_main = [v for v in ('kaopen_j', 'kaopen_j_x_dest_safe')
                       if v in bp.columns]
        estimate(bp, 'log_portfolio_debt',
                 base_safe + kaopen_main + kaopen_avail + triple_interact,
                 "M3: triple dZ×KA×safe")

    # ── SECTION D: Origin Aging Specification ──
    _print_section("D. Origin Aging × dest_safe (Alternative)")
    origin_z = ['Z_1_i', 'Z_2_i', 'Z_3_i']
    origin_avail = [v for v in origin_z if v in bp.columns]
    origin_safe = [f'{z}_x_dest_safe' for z in origin_z
                   if f'{z}_x_dest_safe' in bp.columns]
    if origin_safe:
        estimate(bp, 'log_portfolio_debt',
                 gravity_avail + origin_avail + ['dest_safe'] + origin_safe,
                 "M4: Z_i×dest_safe → debt")

    # ── SECTION E: Robustness ──
    _print_section("E. Robustness Checks")

    # M5a: drop pairs with a financial centre on either side.
    bp_nofc = bp[~bp['iso_o'].isin(FINANCIAL_CENTERS) &
                 ~bp['iso_d'].isin(FINANCIAL_CENTERS)].copy()
    print(f"  Excluding FCs: {len(bp_nofc):,} obs (was {len(bp):,})")
    estimate(bp_nofc, 'log_portfolio_debt', base_safe, "M5a: excl FC")

    # M5b / M5c: OECD vs non-OECD origins.
    is_oecd_origin = bp['iso_o'].isin(OECD_38)
    estimate(bp[is_oecd_origin].copy(), 'log_portfolio_debt', base_safe,
             "M5b: OECD origin")
    estimate(bp[~is_oecd_origin].copy(), 'log_portfolio_debt', base_safe,
             "M5c: non-OECD origin")

    # M5d: add the interest-rate differential as a control.
    if 'rate_diff_ij' in bp.columns:
        estimate(bp, 'log_portfolio_debt', base_safe + ['rate_diff_ij'],
                 "M5d: + rate_diff control")

    # ── Build results table ──
    print("\n\nBuilding results table ...")
    build_table(all_results)

    print("\n" + "=" * 70)
    print("Phase 5 complete.")
    print("=" * 70)


def build_table(all_results):
    """Write Table 5 as markdown, summarising every estimated specification.

    Parameters
    ----------
    all_results : list of dict
        Result dicts produced by ``run_gravity``. Each has ``label``,
        ``dep_var``, ``n_obs``, ``n_pairs``, ``r_squared``, ``rho`` and
        ``coef_<var>``/``se_<var>``/``p_<var>`` entries. If the list is
        empty, a notice is printed and no file is written.

    The table is written as UTF-8 (it contains non-ASCII symbols such as
    '—', 'R²', 'ρ', '×') to ``TABLES_DIR / "table5_gravity_safe.md"``. The
    output directory is created if it does not exist yet.
    """
    if not all_results:
        print("  No results to tabulate.")
        return

    key_vars = ['dZ_1', 'dZ_2', 'dZ_3',
                'dest_safe',
                'dZ_1_x_dest_safe', 'dZ_2_x_dest_safe', 'dZ_3_x_dest_safe',
                'Z_1_i', 'Z_1_i_x_dest_safe',
                'dZ_1_x_kaopen_dest_safe',
                'rate_diff_ij']

    md = ["# Table 5: Bilateral Gravity — Safe Asset Seeking\n"]

    md.append("## Model Summary\n")
    md.append("| Model | Dep Var | N | Pairs | R² | ρ |")
    md.append("|---|---|---|---|---|---|")
    for r in all_results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
                  f"| {r['n_pairs']:,} | {r['r_squared']:.3f} | {r['rho']:.3f} |")

    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in all_results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")

    md.append("\n*Gravity controls: log_dist, contiguity, common_lang_official, "
              "colonial_ties, log_gdp_product.*")
    md.append("*PanelGLS with AR(1) correction on bilateral pairs.*")
    md.append("*dest_safe = destination country rated AA- or above (S&P, time-varying).*")

    # Comparison panel: debt vs equity safe-seeking. Several models share a
    # dependent variable (e.g. every M5 robustness run uses debt), so each row
    # is labelled with its model as well as its flow type.
    md.append("\n## Debt vs Equity Safe-Seeking\n")
    md.append("| Model | Flow Type | dZ₁×dest_safe Coef | p-value |")
    md.append("|-------|-----------|---------------------|---------|")
    for r in all_results:
        if 'coef_dZ_1_x_dest_safe' in r:
            p = r['p_dZ_1_x_dest_safe']
            md.append(f"| {r['label']} | {r['dep_var']} | "
                      f"{r['coef_dZ_1_x_dest_safe']:.4f}{stars(p)} | {p:.4f} |")

    out_path = TABLES_DIR / "table5_gravity_safe.md"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out_path}")


# Run all Phase 5 specifications when executed as a script (not on import).
if __name__ == "__main__":
    main()
